# -*- coding: utf-8 -*-

"""
This file defines a function that takes an image or a list of images as input and produces a predicted probability that the post is about a protest, based on stage-1 or stage-2.

"""



#%%

import os
# CUDA settings are placed in the environment here, ahead of the
# TensorFlow import further down in this file.
# Number GPUs by PCI bus id.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
# "-1" matches no device id, so no GPU is made visible
# (presumably to force CPU-only inference -- confirm with the author).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import os
import numpy as np
from keras.models import load_model
from keras.models import model_from_json
import skimage.io as io
import time
# import scikits.image.io as io

import tensorflow as tf
# Keep a module-level handle to TensorFlow's default graph; predict_image
# enters it with graph.as_default() before calling the model.
# NOTE(review): `global` has no effect at module scope; the assignment
# below already creates a module-level name.
global graph
graph = tf.get_default_graph()


## read model files
# The architecture is stored as JSON and the weights as HDF5. The context
# manager guarantees the JSON file is closed even if reading it fails.
# NOTE: both paths are relative to the current working directory.
with open('../modelfiles/image.json', 'r') as json_file:
	loaded_model_json = json_file.read()

# Rebuild the network from its JSON description, then load its weights.
model = model_from_json(loaded_model_json)
model.load_weights("../modelfiles/weights_image.h5")


def rself(x):
	"""Read the image at ``x`` and return it as a greyscale array.

	Args:
		x: image location, passed straight to ``skimage.io.imread``.
	Returns:
		whatever ``io.imread`` returns when called with ``as_grey = True``.

	"""
	# NOTE(review): newer scikit-image releases spell this keyword
	# ``as_gray`` -- confirm against the installed version.
	grey_image = io.imread(x, as_grey = True)
	return grey_image


def predict_image(filepath, prefix = ""):
	"""Return the highest predicted protest probability among the given images.

	Args:
		filepath (str): one or more image paths separated by ``;``.
			Whitespace around each path is stripped and empty entries
			(e.g. produced by a trailing ``;``) are ignored.
		prefix (str): string prepended to every path, e.g. the absolute
			directory holding the images, so the correct location on disk
			can be found.
	Returns:
		the largest value among the flattened model predictions for the
		images found on disk, or ``np.nan`` when none of the listed paths
		is an existing file.

	"""
	# Drop blank entries *before* adding the prefix: otherwise an empty
	# entry would turn into the bare prefix (usually an existing directory)
	# and be handed to the image reader.
	names = [part.strip() for part in filepath.split(";")]
	paths = [prefix + name for name in names if name]

	# Images are read from local disk only; anything that is not a regular
	# file is skipped silently.
	imgs = [io.imread(path) for path in paths if os.path.isfile(path)]
	if not imgs:
		return np.nan

	# Scale pixel values by 1/255 before prediction.
	# NOTE(review): assumes 8-bit images that all share the shape the model
	# expects -- confirm against the training pipeline.
	batch = np.array(imgs) / 255.0

	# Run the prediction inside the module-level TensorFlow graph.
	with graph.as_default():
		y = model.predict(batch).flatten().tolist()
	return max(y)


if __name__ == '__main__':
	# Manual smoke test against two sample images in one local directory.
	# print(...) with a single argument behaves the same on Python 2 and 3,
	# whereas the print statement is a syntax error on Python 3.
	sample_dir = "/Users/han/Codes/Jen Pan/image_classification/testimage/tagged_images_resize/"
	sample_files = "3bfb171bjw1essd1uyo6bj20bc0c9t9e.jpg;01e6dae0gw1eopjcyd45yj20gh0be75e.jpg"
	print(predict_image(sample_files, prefix = sample_dir))


